Denoising Diffusion Implicit Models for the Oxford Flowers 102 Dataset¶

Introduction¶

Oxford Flowers 102 Dataset¶

https://www.robots.ox.ac.uk/~vgg/data/flowers/102/

A 102-category dataset consisting of 102 flower categories. The flowers chosen are commonly occurring in the United Kingdom. Each class consists of between 40 and 258 images. The images have large scale, pose and light variations. In addition, there are categories with large variations within the category, and several very similar categories.

There are a total of 8189 images; the full download is ~330MB.

More information in the TensorFlow Datasets page:

https://www.tensorflow.org/datasets/catalog/oxford_flowers102

The package github.com/gomlx/gomlx/examples/oxfordflowers102 provides a train.Dataset object that can be used to train models with the dataset. It also provides a simplified mechanism to download and cache the dataset.

Denoising Diffusion Implicit Models¶

This notebook is an example of a diffusion model, based on the Keras example in:

https://keras.io/examples/generative/ddim/

The modeling and training code is in github.com/gomlx/gomlx/examples/oxfordflowers102/diffusion. There is a training binary for the command line in the train/ subdirectory; it's a very small wrapper around the diffusion library.

In [1]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx" "${HOME}/Projects/gonb" "${HOME}/Projects/gopjrt" "${HOME}/Projects/bsplines"
%goworkfix
	- Added replace rule for module "github.com/gomlx/gopjrt" to local directory "/home/janpf/Projects/gopjrt".
	- Added replace rule for module "github.com/gomlx/bsplines" to local directory "/home/janpf/Projects/bsplines".
	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".

Hyperparameters and Configuration¶

Some basic parameters are set as flags, but everything else is set as hyperparameters in the context.Context. The hyperparameters can be configured using the --set command-line flag.

Below we define the ContextFromSettings function that we are going to use everywhere, and we print out the available hyperparameters.

See diffusion.CreateDefaultContext for documentation on all hyperparameters.

In [2]:
import (
    "github.com/gomlx/gomlx/examples/oxfordflowers102/diffusion"
    "github.com/gomlx/gomlx/ml/context"

    // Use XLA backend.
    _ "github.com/gomlx/gomlx/backends/xla"
)

var (
	flagDataDir    = flag.String("data", "~/work/oxfordflowers102", "Directory to cache downloaded and generated dataset files.")
	flagEval       = flag.Bool("eval", true, "Whether to evaluate the model on the validation data in the end.")
	flagVerbosity  = flag.Int("verbosity", 1, "Level of verbosity, the higher the more verbose.")
	flagCheckpoint = flag.String("checkpoint", "", "Directory save and load checkpoints from. If left empty, no checkpoints are created.")

    // settings is bound to a "-set" flag to be used to set context hyperparameters.
    settings = commandline.CreateContextSettingsFlag(diffusion.CreateDefaultContext(), "set")
)

// ContextFromSettings is the default context (diffusion.CreateDefaultContext) changed by the -set flag.
// It requires that flags are already parsed.
//
// It also returns the list of parameters that were set.
func ContextFromSettings() (ctx *context.Context, paramsSet []string) {
    ctx = diffusion.CreateDefaultContext()
    paramsSet = must.M1(commandline.ParseContextSettings(ctx, *settings))
    return
}

// ConfigFromFlags returns a diffusion.Config object initialized from the settings.
// It requires that flags are already parsed.
func ConfigFromFlags() *diffusion.Config {
    backend := backends.New()
    ctx, paramsSet := ContextFromSettings()
    return diffusion.NewConfig(backend, ctx, *flagDataDir, paramsSet)
}

%% -set="dtype=float32;train_steps=1_000"
c := ConfigFromFlags()
fmt.Println(commandline.SprintContextSettings(c.Context))
Context hyperparameters:
	"activation": (string) swish
	"adam_dtype": (string) 
	"adam_epsilon": (float64) 1e-07
	"adam_weight_decay": (float64) 0.0001
	"batch_size": (int) 32
	"checkpoint_frequency": (string) 3m
	"cosine_annealing_min_learning_rate": (float64) 1e-05
	"cosine_schedule_steps": (int) 0
	"diffusion_balanced_dataset": (bool) false
	"diffusion_channels_list": ([]int) [32 64 96 128]
	"diffusion_ema": (float64) 0.999
	"diffusion_loss": (string) mae
	"diffusion_max_signal_ratio": (float64) 0.95
	"diffusion_min_signal_ratio": (float64) 0.02
	"diffusion_num_residual_blocks": (int) 4
	"diffusion_pool": (string) mean
	"diffusion_residual_version": (int) 2
	"dropout_rate": (float64) 0.15
	"dtype": (string) float32
	"eval_batch_size": (int) 128
	"flower_type_embed_size": (int) 16
	"huber_delta": (float64) 0.2
	"image_size": (int) 64
	"kid": (bool) false
	"l1_regularization": (float64) 0
	"l2_regularization": (float64) 0
	"learning_rate": (float64) 0.001
	"nan_logger": (bool) false
	"normalization": (string) layer
	"num_checkpoints": (int) 5
	"optimizer": (string) adam
	"plots": (bool) true
	"rng_reset": (bool) true
	"samples_during_training": (int) 64
	"samples_during_training_frequency": (int) 200
	"samples_during_training_frequency_growth": (float64) 1.2
	"sinusoidal_embed_size": (int) 32
	"sinusoidal_max_freq": (float64) 1000
	"sinusoidal_min_freq": (float64) 1
	"train_steps": (int) 1000
	"use_ema": (bool) false

Dataset Management¶

Data directory and downloading the dataset¶

Below we call DownloadAndParse: it will only download if the dataset files are not yet present; otherwise it's a no-op.

This cell also adds some imports we will use everywhere later.

In [3]:
import (
    "flag"
    flowers "github.com/gomlx/gomlx/examples/oxfordflowers102"
    "github.com/janpfeifer/must"
)

%%
c := ConfigFromFlags()
must.M(flowers.DownloadAndParse(c.DataDir))
fmt.Println("Oxford Flowers 102 dataset downloaded:")
fmt.Printf("\t%d images, %d labels, %d examples\n", len(flowers.AllImages), len(flowers.AllLabels), flowers.NumExamples)
Oxford Flowers 102 dataset downloaded:
	8189 images, 8189 labels, 8189 examples

Sample of Flowers¶

Below we create a temporary dataset (with NewDataset) of 256x256-pixel images, and then show a sample of the flowers.

Later we will use a model that uses only 64x64 pixels.

In [4]:
import (
    timage "github.com/gomlx/gomlx/pkg/core/tensors/images"
)

// sampleTable generates and outputs one html table of samples, sampling rows x cols from the images/labels provided.
func sampleTable(title string, ds train.Dataset, rows, cols int) {
    htmlRows := make([]string, 0, rows)
    for row := 0; row < rows; row++ {
        cells := make([]string, 0, cols)
        for col := 0; col < cols; col++ {
            cells = append(cells, sampleOneImage(ds))
        }
        htmlRows = append(htmlRows, fmt.Sprintf("<tr>\n\t<td>%s</td>\n</tr>", strings.Join(cells, "</td>\n\t<td>")))
    }
    htmlTable := fmt.Sprintf("<h4>%s</h4><table>%s</table>\n", title, strings.Join(htmlRows, ""))
    gonbui.DisplayHTML(htmlTable)
}

// sampleOneImage one image from tensor and returns an HTML rendered image with label
func sampleOneImage(ds train.Dataset) string {
    _, inputs, labels := must.M3(ds.Yield())
    imgTensor := inputs[0]
    img := timage.ToImage().Single(imgTensor)
    exampleNum := inputs[1].Value().(int64)
    label := labels[0].Value().(int32)
    labelStr := flowers.Names[label]
    
    imgSrc := must.M1(gonbui.EmbedImageAsPNGSrc(img))
    size := imgTensor.Shape().Dimensions[0]
    return fmt.Sprintf(`<figure style="padding:4px;text-align: center;"><img width="%d" height="%d" src="%s"><figcaption style="text-align: center;">Example %d:<br/><span>%s (%d)</span><br/>(%dx%d pixels)</figcaption></figure>`, 
                       size, size, imgSrc, exampleNum, labelStr, label, img.Bounds().Dx(), img.Bounds().Dy())
}

%% --set="image_size=256"
c := ConfigFromFlags()
must.M(flowers.DownloadAndParse(c.DataDir))
ds := flowers.NewDataset(dtypes.U8, c.ImageSize)
ds.Shuffle()
sampleTable("Oxford 102 Flowers Sample", ds, 4, 6)

Oxford 102 Flowers Sample

[4×6 grid of 256x256-pixel images omitted; sampled examples:]

  • Example 6065: spear thistle (13)
  • Example 4973: marigold (46)
  • Example 3124: snapdragon (10)
  • Example 7290: cape flower (36)
  • Example 2350: sword lily (42)
  • Example 7194: tiger lily (5)
  • Example 6611: canterbury bells (2)
  • Example 4702: pelargonium (54)
  • Example 3038: bee balm (91)
  • Example 1219: rose (73)
  • Example 6067: spear thistle (13)
  • Example 1818: hibiscus (82)
  • Example 527: cyclamen (87)
  • Example 3464: sweet william (29)
  • Example 2370: sword lily (42)
  • Example 6376: yellow iris (14)
  • Example 2061: anthurium (79)
  • Example 1413: petunia (50)
  • Example 4266: peruvian lily (17)
  • Example 6800: fire lily (20)
  • Example 2658: geranium (57)
  • Example 991: wallflower (45)
  • Example 2620: columbine (83)
  • Example 1252: rose (73)

In-Memory Dataset for Fast Access¶

We convert the flowers dataset to an InMemoryDataset and cache its contents for faster start-up time.

The first time it runs it will read and convert all images to the target size. But it then saves a cache of the generated content, so the second time it is faster.

From a local benchmark (go test -bench=. -test.run=Benchmark, with --batch=64):

  • Directly reading (and parsing) from disk: ~215 ms/batch.
  • Parallelized (24 cores) reading from disk: ~25 ms/batch.
  • InMemory batches in GPU: ~41 µs/batch.
In [5]:
// Remove cached files to force regeneration.
!rm -f "${HOME}/work/oxfordflowers102/"*_cached_images_*

%%
c := ConfigFromFlags()
trainDS, validationDS := c.CreateInMemoryDatasets()

fmt.Println()
fmt.Printf("Total number of examples: #train=%d, #validation=%d\n", trainDS.NumExamples(), validationDS.NumExamples())
fmt.Printf("trainDS (in-memory) using %s of memory.\n", data.ByteCountIEC(trainDS.Memory()))
fmt.Printf("validationDS (in-memory) using %s of memory.\n", data.ByteCountIEC(validationDS.Memory()))

// Output a random sample.
trainDS.Shuffle()
sampleTable("Oxford 102 Flowers Sample -- In-Memory Dataset", trainDS, 1, 6)
Creating InMemoryDataset for "train" with images cropped and scaled to 64x64...
	- 4.393288377s to process dataset.
Creating InMemoryDataset for "validation" with images cropped and scaled to 64x64...
	- 1.467081864s to process dataset.

Total number of examples: #train=6487, #validation=1702
trainDS (in-memory) using 76.1 MiB of memory.
validationDS (in-memory) using 20.0 MiB of memory.

Oxford 102 Flowers Sample -- In-Memory Dataset

[1×6 row of 64x64-pixel images omitted; sampled examples:]

  • Example 5973: windflower (68)
  • Example 5743: king protea (12)
  • Example 3658: primula (52)
  • Example 4170: wild pansy (51)
  • Example 1417: petunia (50)
  • Example 4636: buttercup (47)

Denoising Diffusion Implicit Model¶

Preprocessing of images¶

The diffusion model takes images normalized to a mean of 0 and standard deviation of 1, and generates images in the same range.

The functions PreprocessImages and DenormalizeImages convert images to floats and normalize/denormalize them.

Below we quickly test that calling PreprocessImages and then DenormalizeImages has no effect on a random batch of images.

In [6]:
import (
    . "github.com/gomlx/gomlx/pkg/core/graph"
    "github.com/gomlx/gomlx/ml/context"
    "github.com/gomlx/gomlx/ml/data"
   "github.com/janpfeifer/gonb/gonbui"
)

var _ = NewGraph  // Avoid warnings of non-used import.

%% --set="image_size=128"
c := ConfigFromFlags()
mean, stddev := c.NormalizationValues()
fmt.Printf("Flower images, per channel (red, green, blue):\n\t  mean=%v\n\tstddev=%v\n", mean.Value(), stddev.Value())

trainDS, _ := c.CreateInMemoryDatasets()
trainDS.Shuffle()
trainDS.BatchSize(6, true)
_, inputs, _ := must.M3(trainDS.Yield())
gonbui.DisplayHTML("<p><b>Original:</b></p>")
diffusion.PlotImagesTensor(inputs[0])

e := NewExec(c.Backend, func(images *Node) *Node {
    images = c.PreprocessImages(images, true)
    images = c.DenormalizeImages(images)
    return images
})
gonbui.DisplayHTML("<p><b>After normalization and denormalization:</b></p>")
imagesT := e.Call(inputs[0])[0]
fmt.Printf("imagesT.shape=%s\n", imagesT.Shape())
diffusion.PlotImagesTensor(imagesT)
Flower images, per channel (red, green, blue):
	  mean=[[[[121.027176 100.22015 78.19373]]]]
	stddev=[[[[75.55288 62.183628 69.93133]]]]
Creating InMemoryDataset for "train" with images cropped and scaled to 128x128...
	- 4.74725111s to process dataset.
Creating InMemoryDataset for "validation" with images cropped and scaled to 128x128...
	- 1.636961804s to process dataset.

Original:

[6 images omitted]

After normalization and denormalization:

imagesT.shape=(Float32)[6 128 128 3]
[6 images omitted]

Sinusoidal Embedding¶

Used to embed the variance of the noise at different frequencies.
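The idea can be sketched in plain Go: a scalar is expanded into sin/cos features at geometrically spaced frequencies. The layout and defaults below mirror the sinusoidal_embed_size, sinusoidal_min_freq and sinusoidal_max_freq hyperparameters, but this is an illustrative sketch, not the library's actual implementation:

```go
package main

import (
	"fmt"
	"math"
)

// sinusoidalEmbedding maps a scalar x (e.g. the noise variance) to a vector
// of sin/cos features at frequencies geometrically interpolated between
// minFreq and maxFreq. embedSize is assumed even and >= 4.
func sinusoidalEmbedding(x float64, embedSize int, minFreq, maxFreq float64) []float64 {
	numFreqs := embedSize / 2
	emb := make([]float64, embedSize)
	for i := 0; i < numFreqs; i++ {
		// Geometric interpolation from minFreq to maxFreq.
		t := float64(i) / float64(numFreqs-1)
		freq := minFreq * math.Pow(maxFreq/minFreq, t)
		angle := 2 * math.Pi * freq * x
		emb[i] = math.Sin(angle)          // first half: sines
		emb[numFreqs+i] = math.Cos(angle) // second half: cosines
	}
	return emb
}

func main() {
	emb := sinusoidalEmbedding(1.0, 32, 1.0, 1000.0)
	fmt.Println(len(emb)) // prints 32
}
```

The mix of slow and fast frequencies lets the network distinguish both small and large differences in the noise level from a single scalar input.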

In [7]:
import (
    "github.com/gomlx/gomlx/examples/oxfordflowers102/diffusion"
    . "github.com/gomlx/gomlx/pkg/core/graph"
)

%%
c := ConfigFromFlags()
value := NewExec(c.Backend, func (x *Node) *Node {
    return diffusion.SinusoidalEmbedding(c.Context, x)
}).Call(float32(1.0))[0]
fmt.Printf("SinusoidalEmbedding(1.0)=\n\tShape: %s\n\tValue: %v\n", value.Shape(), value.Value())
SinusoidalEmbedding(1.0)=
	Shape: (Float32)[32]
	Value: [1.7484555e-07 -0.5084644 -0.074616365 -0.11864995 0.93075866 2.70213e-06 -0.8129459 0.6793376 -0.92810476 0.5659511 1.176251e-05 0.06701087 0.9265504 0.62322754 -0.26391345 0.0007279766 1 -0.86108303 -0.9972123 0.99293613 -0.3656342 1 0.5823393 0.73382586 0.37231916 0.8244388 1 -0.99775225 0.37617072 0.7820406 0.9645464 0.99999976]

U-Net Model¶

The code in diffusion.UNetModelGraph follows the Keras example's Network Architecture.

The following are the modeling functions:

  • UNetModelGraph builds the computation graph mapping a noisy image to (predicted image, predicted noise), using a U-Net model; it's the core of this example.
  • BuildTrainingModelGraph builds the train.ModelFn (the function that GoMLX uses for a training loop). It takes raw images as examples, adds some random noise at a random time (from 0.0 to 1.0), and uses the U-Net model to try to separate the noise. It returns the predicted image and the loss, where the loss is measured on the predicted noise -- this works better for learning than predicting the original image (*).

(*) My hypothesis is that predicting the original image makes it easier to overfit (since we have a limited number of images, but can generate infinite noise samples).
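The training objective described above can be sketched in plain Go. This is a toy illustration of mixing signal and noise with a cosine schedule and measuring the "mae" diffusion_loss on the predicted noise; the real BuildTrainingModelGraph works on GoMLX computation graphs and batched tensors, and the function names here are illustrative:

```go
package main

import (
	"fmt"
	"math"
	"math/rand"
)

// noisify mixes a clean image with Gaussian noise at a diffusion time
// t in [0,1): signalRatio=cos(t·π/2), noiseRatio=sin(t·π/2), so that
// signalRatio² + noiseRatio² = 1 and the total variance is preserved.
func noisify(image []float64, t float64) (noisy, noise []float64) {
	signalRatio := math.Cos(t * math.Pi / 2)
	noiseRatio := math.Sin(t * math.Pi / 2)
	noisy = make([]float64, len(image))
	noise = make([]float64, len(image))
	for i, v := range image {
		noise[i] = rand.NormFloat64()
		noisy[i] = signalRatio*v + noiseRatio*noise[i]
	}
	return
}

// maeLoss mirrors the "mae" diffusion_loss: mean absolute error between the
// noise the model predicted and the noise actually added.
func maeLoss(predictedNoise, noise []float64) float64 {
	sum := 0.0
	for i := range noise {
		sum += math.Abs(predictedNoise[i] - noise[i])
	}
	return sum / float64(len(noise))
}

func main() {
	image := []float64{0.1, -0.3, 0.7, 0.2} // stand-in for normalized pixels
	noisy, noise := noisify(image, rand.Float64())
	// A real training step would run the U-Net on `noisy` to predict the
	// noise; here we fake a (bad) prediction of all zeros.
	predicted := make([]float64, len(noisy))
	fmt.Printf("loss=%.3f\n", maeLoss(predicted, noise))
}
```

Because the noise is sampled fresh on every step, the model effectively sees an unlimited stream of (noisy image, noise) training pairs from a finite image set.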

The model in its default configuration uses a few million parameters (~4.4 million in the run below):

In [8]:
import (
    . "github.com/gomlx/gomlx/pkg/core/graph"
    "github.com/gomlx/gomlx/ml/context"
    "github.com/gomlx/gomlx/types/shapes"
    "github.com/gomlx/gopjrt/dtypes"
)

// batch_size=5 just for testing.
%% --set="flower_type_embed_size=16;batch_size=5"
c := ConfigFromFlags()
fmt.Printf("Backend %q: %s\n", c.Backend.Name(), c.Backend.Description())

fmt.Println("\nUNetModelGraph:")
g := NewGraph(c.Backend, "test")
noisyImages := Zeros(g, shapes.Make(c.DType, c.BatchSize, 64, 64, 3))
flowerIds := Zeros(g, shapes.Make(dtypes.Int32, c.BatchSize))
fmt.Printf("  noisyImages.shape:\t%s\n", noisyImages.Shape())
filtered := diffusion.UNetModelGraph(c.Context, noisyImages, Ones(g, shapes.Make(c.DType, 5, 1, 1, 1)), flowerIds)
fmt.Printf("     filtered.shape:\t%s\n", filtered.Shape())
fmt.Printf("U-Net Model #params:\t%d\n", c.Context.NumParameters())
fmt.Printf(" U-Net Model memory:\t%s\n", data.ByteCountIEC(c.Context.Memory()))

fmt.Println("\nModelGraph:")
images := Zeros(g, shapes.Make(c.DType, 5, c.ImageSize, c.ImageSize, 3))
fmt.Printf("         images.shape:\t%s\n", images.Shape())
modelFn := c.BuildTrainingModelGraph()
predictions := modelFn(c.Context.Reuse(), nil, []*Node{images, nil, flowerIds})
fmt.Printf("predictedImages.shape:\t%s\n", predictions[0].Shape())
fmt.Printf("           loss.shape:\t%s\n", predictions[1].Shape())
fmt.Printf("        Model #params:\t%d\n", c.Context.NumParameters())
fmt.Printf("         Model memory:\t%s\n", data.ByteCountIEC(c.Context.Memory()))
Backend "xla": xla:cuda - PJRT "cuda" plugin (/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.54

UNetModelGraph:
  noisyImages.shape:	(Float32)[5 64 64 3]
     filtered.shape:	(Float32)[5 64 64 3]
U-Net Model #params:	4400646
 U-Net Model memory:	16.8 MiB

ModelGraph:
         images.shape:	(Float32)[5 64 64 3]
predictedImages.shape:	(Float32)[5 64 64 3]
           loss.shape:	(Float32)
        Model #params:	4400646
         Model memory:	16.8 MiB

Training Model¶

The training was mostly done from the command line -- easier to leave it running for hours -- using the train program.

It can be installed with go install github.com/gomlx/gomlx/examples/oxfordflowers102/diffusion/train@latest, but usually I would just go to the directory and do a go run . <...flags...>; see some examples below. It also requires github.com/gomlx/gopjrt installed with the PJRT plugins for your device (CPU, GPU, etc.).

The train program (and the library function) saves evaluation points as it trains, and these can be plotted; see below for an example. If the training is interrupted and restarted, it continues where it left off.

See also github.com/gomlx/gomlx/cmd/gomlx_checkpoints to pretty-print the metrics, a model summary, its variables and its hyperparameters from the command line.

During the writing of this notebook, a few such models were generated with different hyperparameters. But the space was not explored much -- if someone with more GPU resources is willing to try some hyperparameter tuning, or has a better diffusion model to use, please share!

Training from the notebook¶

Because it takes many hours, we recommend training from the command line. But it can be done from the notebook as well -- it uses the same training function -- which can be useful for testing and development.

In [9]:
!rm -rf ~/work/oxfordflowers102/test/
In [10]:
%% --checkpoint "test" --set="train_steps=2000;plots=true;diffusion_num_residual_blocks=2"
c := ConfigFromFlags()
fmt.Printf("Backend %q: %s\n", c.Backend.Name(), c.Backend.Description())
diffusion.TrainModel(c.Context, *flagDataDir, *flagCheckpoint, c.ParamsSet, *flagEval, *flagVerbosity)
Backend "xla": xla:cuda - PJRT "cuda" plugin (/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.54
Backend "xla":	xla:cuda - PJRT "cuda" plugin (/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.54
	Checkpoint: "/home/janpf/work/oxfordflowers102/test"
	train_steps=2000
	plots=true
	diffusion_num_residual_blocks=2
Training (2000 steps):   36% [=============>..........................] (7 steps/s) [2m33s:2m53s] [step=727] [loss+=0.211] [~loss+=0.208] [~loss=0.208] [~img_loss=0.211] [~fast_loss=0.206] [~mae=0.206]        
Training (2000 steps):  100% [========================================] (6 steps/s) [step=1999] [loss+=0.214] [~loss+=0.195] [~loss=0.195] [~img_loss=0.203] [~fast_loss=0.194] [~mae=0.194]          
	[Step 2000] median train step: 132139 microseconds

Results on train:
	Mean Loss+Regularization (#loss+): 0.196
	Mean Loss (#loss): 0.196
	Images Loss (img_loss): 0.198
	MAE Loss (#mae): 0.196
Results on validation:
	Mean Loss+Regularization (#loss+): 0.197
	Mean Loss (#loss): 0.197
	Images Loss (img_loss): 0.194
	MAE Loss (#mae): 0.197

Generating images from a model¶

To generate images, we build and execute the model over several steps, each assuming less noise and more signal. It starts from purely randomly generated noise.

The function GenerateImages(numImages, numSteps, displayEveryNSteps) orchestrates this for us.

This only works with a trained model saved to a checkpoint.

The "use_ema=true" hyperparameter tells it to use an exponential moving average (EMA) of the weights for evaluation (it requires the model to have been trained with it enabled).
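The reverse-diffusion loop can be sketched in plain Go. The cos/sin signal and noise ratios below are bounded by values matching the diffusion_min_signal_ratio/diffusion_max_signal_ratio hyperparameters, but the stand-in predictNoise replaces the trained U-Net, so this only illustrates the control flow, not the library's actual implementation:

```go
package main

import (
	"fmt"
	"math"
	"math/rand"
)

// predictNoise stands in for the trained U-Net, which predicts the noise
// component of a noisy image. This stub assumes the image is still pure noise
// scaled by the current noise ratio, which keeps the toy loop numerically
// stable; a real model is a learned function.
func predictNoise(noisy []float64, noiseRatio float64) []float64 {
	out := make([]float64, len(noisy))
	for i, v := range noisy {
		out[i] = noiseRatio * v
	}
	return out
}

// generate runs numSteps of DDIM-style sampling: start from pure noise; at
// each step estimate the clean image from the predicted noise, then re-mix
// image and noise estimates at the next (higher) signal ratio.
func generate(numPixels, numSteps int) []float64 {
	const minSignal, maxSignal = 0.02, 0.95 // as in the default hyperparameters
	angleStart := math.Acos(minSignal)      // almost pure noise
	angleEnd := math.Acos(maxSignal)        // almost pure signal

	img := make([]float64, numPixels)
	for i := range img {
		img[i] = rand.NormFloat64() // start from pure noise
	}
	for step := 0; step < numSteps; step++ {
		frac := float64(step) / float64(numSteps)
		angle := angleStart + (angleEnd-angleStart)*frac
		signalRatio, noiseRatio := math.Cos(angle), math.Sin(angle)

		fracNext := float64(step+1) / float64(numSteps)
		angleNext := angleStart + (angleEnd-angleStart)*fracNext

		predNoise := predictNoise(img, noiseRatio)
		for i := range img {
			// Invert noisy = signal*image + noise*eps to estimate the image...
			predImage := (img[i] - noiseRatio*predNoise[i]) / signalRatio
			// ...and re-mix at the next step's ratios.
			img[i] = math.Cos(angleNext)*predImage + math.Sin(angleNext)*predNoise[i]
		}
	}
	return img
}

func main() {
	img := generate(8, 20)
	fmt.Println(len(img)) // prints 8
}
```

With more steps the schedule changes more gradually between consecutive mixes, which is why increasing numSteps generally improves sample quality at the cost of more U-Net evaluations.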

In [11]:
%% --checkpoint="v2_64x64" -set="use_ema=true"
c := ConfigFromFlags()
_, _, _ = c.AttachCheckpoint(*flagCheckpoint)
c.DisplayImagesAcrossDiffusionSteps(12, 20, 10)
	Checkpoint: "/home/janpf/work/oxfordflowers102/v2_64x64"
DisplayImagesAcrossDiffusionSteps(12 images, 20 steps): noise.shape=(Float32)[12 64 64 3]
	Model #params:	6841702
	 Model memory:	26.1 MiB

Noise

[12 images omitted]

5.0% Denoised -- Step 1/20

[12 images omitted]

55.0% Denoised -- Step 11/20

[12 images omitted]

100.0% Denoised -- Step 20/20

[12 images omitted]

Generating 102 Images For Each Flower Type, Same Noise¶

Notice our current model is not very good yet: it generates more or less the same image for every flower type.

In [12]:
%% --checkpoint="v2_64x64" --set="use_ema=false"
c := ConfigFromFlags()
_, _, _ = c.AttachCheckpoint(*flagCheckpoint)
diffusion.PlotImagesTensor(c.GenerateImagesOfAllFlowerTypes(20))
	Checkpoint: "/home/janpf/work/oxfordflowers102/v2_64x64"
[102 images omitted, one per flower type]

Fixed starting noise, generated flower at different points in training¶

When a new model starts training, the diffusion.TrainModel() function saves a set of random noises, and every time the model is evaluated, images for the corresponding checkpoint are generated from these fixed noises. This way one can see how the model evolves during training.

Below are the images generated by our v2_64x64 model at various steps during training:

In [13]:
%% --checkpoint="v2_64x64"
c := ConfigFromFlags()
_, _, _ = c.AttachCheckpoint(*flagCheckpoint)
c.PlotModelEvolution(20, /* animated */ true)
	Checkpoint: "/home/janpf/work/oxfordflowers102/v2_64x64"

Generated samples in /home/janpf/work/oxfordflowers102/v2_64x64:

Sample generated images per flower type¶

For a random sample of 5 flower types, we generate 12 images of each.

Definitely not as good as the ones from the previous version, displayed at the bottom.

In [14]:
import flowers "github.com/gomlx/gomlx/examples/oxfordflowers102"

%% --checkpoint="v2_64x64"
c := ConfigFromFlags()
_, _, _ = c.AttachCheckpoint(*flagCheckpoint)
for ii := 0; ii < 5; ii++ {
    flowerType := int32(rand.Intn(flowers.NumLabels))
    gonbui.DisplayHTML(fmt.Sprintf("<p>Generated <b>%s</b></p>\n", flowers.Names[flowerType]))
    diffusion.PlotImagesTensor(c.GenerateImagesOfFlowerType(12, flowerType, 30))
}
	Checkpoint: "/home/janpf/work/oxfordflowers102/v2_64x64"

Generated pincushion flower

No description has been provided for this image No description has been provided for this image No description has been provided for this image No description has been provided for this image No description has been provided for this image No description has been provided for this image No description has been provided for this image No description has been provided for this image No description has been provided for this image No description has been provided for this image No description has been provided for this image No description has been provided for this image

Generated tree mallow


Generated frangipani


Generated morning glory


Generated osteospermum


Older Version¶

Below are results using an older version of GoMLX: the models trained then were better, but I didn't retrain them on the newer version.

TODO: improve them with more training time, more hyperparameter tuning, or a better model.

Larger 128x128 model with Transformer blocks¶

Below are some random images generated by a model trained at 128x128 resolution, with more blocks and 4 attention layers in the middle.

I also tried the mean squared error (MSE) loss function.

In [22]:
%% --checkpoint="model_128x128_01" --size=128 --att_layers=4 --blocks=6 --norm=layer --activation=sigmoid --channels_list=16,32,64,96,128 --loss=mse --checkpoint_mean=-1
diffusion.PlotImagesTensor(diffusion.GenerateImages(90, 20, 0))

Model conditioned on flower types¶

The flag --flower_type_dim=16 uses the flower type (flowerIds) as a feature, embedding it with the given dimension at the start of each block.

We trained the model model_64x64_02 with 200K steps using this flag:

$ go run . --steps=200000 --plots --checkpoint="model_64x64_02" --norm=layer --learning_rate=1e-3 --flower_type_dim=16

Below is the generated output for a few random flower types:

In [8]:
%% --checkpoint="model_64x64_02" --norm=layer --learning_rate=1e-3 --flower_type_dim=16  --checkpoint_mean=-1
for ii := 0; ii < 5; ii++ {
    flowerType := int32(rand.Intn(flowers.NumLabels))
    gonbui.DisplayHTML(fmt.Sprintf("<p>Generated <b>%s</b></p>\n", flowers.Names[flowerType]))
    diffusion.PlotImagesTensor(diffusion.GenerateImagesOfFlowerType(18, flowerType, 30))
}

Generated gazania


Generated lotus


Generated alpine sea holly


Generated wallflower


Generated thorn apple


One starting noise, different flower types¶

In [9]:
%% --checkpoint="model_64x64_02" 
c := ConfigFromFlags()
_, _, _ = c.AttachCheckpoint(*flagCheckpoint)
diffusion.PlotImagesTensor(c.GenerateImagesOfAllFlowerTypes(20))
In [ ]:
%% --checkpoint="model_64x64_02" --norm=layer --learning_rate=1e-3 --flower_type_dim=16  --checkpoint_mean=-1
// Load model
ctx := context.NewContext(manager).Checked(false)
_, _, _ = diffusion.LoadCheckpointToContext(ctx)
ctx.RngStateReset()

// Create UI with diffusion generated flowers.
divId := dom.CreateTransientDiv()
// cache.ResetKey("slider_diffusion_steps")
doneSteps := diffusion.SliderDiffusionSteps("slider_diffusion_steps", ctx, 8, 30, divId)
// cache.ResetKey("dropdown_flower_types")
doneFlowers := diffusion.DropdownFlowerTypes("dropdown_flower_types", ctx, 8, 20, divId)

// Wait for OK button.
button := widgets.Button("Ok").AppendTo(divId).Done()
<-button.Listen().C

// Clean up and persist HTML (so it can be saved).
doneSteps.Trigger()
doneFlowers.Trigger()
dom.Persist(divId)
In [7]:
%% --checkpoint="model_64x64_02" --norm=layer --learning_rate=1e-3 --flower_type_dim=16  --checkpoint_mean=-1
// Load model
ctx := context.NewContext(manager).Checked(false)
_, _, _ = diffusion.LoadCheckpointToContext(ctx)

// Create UI with diffusion generated flowers.
divId := dom.CreateTransientDiv()
diffusion.SliderDiffusionSteps("slider_diffusion_steps", ctx, 8, 30, divId)
diffusion.DropdownFlowerTypes("dropdown_flower_types", ctx, 8, 20, divId)

// Wait for OK button.
button := widgets.Button("Ok").AppendTo(divId).Done()
<-button.Listen().C